use std::ops::Deref;

use crate::scene::Vertex;

/// A shader module bundled with the vertex buffer layouts and color targets that are handed
/// to the vertex and fragment stages of a render pipeline built from it.
///
/// `'a` is the lifetime carried by the stored [`wgpu::VertexBufferLayout`]s.
#[derive(Debug)]
pub struct Shader<'a> {
    /// The module created on the device.
    shader: wgpu::ShaderModule,
    /// Vertex buffer layouts passed to the vertex stage; the index in this list is the
    /// vertex buffer slot.
    buffers: Vec<wgpu::VertexBufferLayout<'a>>,
    /// Color targets passed to the fragment stage; the index in this list is the color
    /// attachment index.
    targets: Vec<Option<wgpu::ColorTargetState>>,
}

impl<'a> Shader<'a> {
    /// Name of the vertex-stage entry point used by [`Shader::vertex`].
    pub const VERTEX_ENTRY_POINT: &'static str = "vert";

    /// Name of the fragment-stage entry point used by [`Shader::fragment`].
    pub const FRAGMENT_ENTRY_POINT: &'static str = "frag";

    /// Creates a shader module from `desc` on `ctx.device`, with no vertex buffer layouts
    /// and no color targets registered yet.
    pub fn new(ctx: &super::Context, desc: wgpu::ShaderModuleDescriptor) -> Self {
        Self {
            shader: ctx.device.create_shader_module(desc),
            buffers: Vec::new(),
            targets: Vec::new(),
        }
    }

    /// Creates a shader module from `desc` and pre-registers `V`'s buffer layout plus one
    /// color target for the surface format in `ctx.config` (see
    /// [`Shader::add_screen_color_target`]).
    pub fn with_vertex_screen<V: Vertex>(
        ctx: &super::Context,
        desc: wgpu::ShaderModuleDescriptor,
    ) -> Self {
        // Delegate to the incremental builders so the "screen" color target is defined in
        // exactly one place.
        let mut shader = Self::new(ctx, desc);
        shader.add_buffer_layout::<V>();
        shader.add_screen_color_target(ctx);
        shader
    }

    /// Builds a shader from WGSL source text with `V`'s buffer layout and a screen color
    /// target, as [`Shader::with_vertex_screen`] does. The module has no label.
    pub fn from_src<V: Vertex>(ctx: &super::Context, src: &str) -> Self {
        Self::with_vertex_screen::<V>(
            ctx,
            wgpu::ShaderModuleDescriptor {
                label: None,
                source: wgpu::ShaderSource::Wgsl(src.into()),
            },
        )
    }

    /// Appends `V::layout()` to the vertex buffer layouts.
    pub fn add_buffer_layout<V: Vertex>(&mut self) {
        self.buffers.push(V::layout());
    }

    /// Appends `color_target` to the fragment color targets.
    pub fn add_color_target(&mut self, color_target: wgpu::ColorTargetState) {
        self.targets.push(Some(color_target));
    }

    /// Appends a color target whose format is `ctx.config.format`, with no blending and all
    /// color channels written.
    pub fn add_screen_color_target(&mut self, ctx: &super::Context) {
        self.add_color_target(wgpu::ColorTargetState {
            format: ctx.config.format,
            blend: None,
            write_mask: wgpu::ColorWrites::ALL,
        });
    }

    /// Removes every registered vertex buffer layout.
    pub fn clear_buffer_layouts(&mut self) {
        self.buffers.clear();
    }

    /// Removes every registered color target.
    pub fn clear_targets(&mut self) {
        self.targets.clear();
    }

    /// Returns the vertex stage description: this module, entry point
    /// [`Shader::VERTEX_ENTRY_POINT`], and the registered buffer layouts.
    ///
    /// The result only borrows `self` for as long as it lives; it is not tied to `'a`.
    pub fn vertex(&self) -> wgpu::VertexState<'_> {
        wgpu::VertexState {
            module: &self.shader,
            entry_point: Self::VERTEX_ENTRY_POINT,
            buffers: &self.buffers,
        }
    }

    /// Returns the fragment stage description: this module, entry point
    /// [`Shader::FRAGMENT_ENTRY_POINT`], and the registered color targets.
    ///
    /// The result only borrows `self` for as long as it lives; it is not tied to `'a`.
    pub fn fragment(&self) -> wgpu::FragmentState<'_> {
        wgpu::FragmentState {
            module: &self.shader,
            entry_point: Self::FRAGMENT_ENTRY_POINT,
            targets: &self.targets,
        }
    }
}

/// Exposes the wrapped module, so a `&Shader` can be used where a `&wgpu::ShaderModule`
/// is expected.
impl Deref for Shader<'_> {
    type Target = wgpu::ShaderModule;

    fn deref(&self) -> &wgpu::ShaderModule {
        let Self { shader, .. } = self;
        shader
    }
}
